{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 深度概率编程库"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MindSpore深度概率编程的目标是将深度学习和贝叶斯学习结合，包括概率分布、概率分布映射、深度概率网络、概率推断算法、贝叶斯层、贝叶斯转换和贝叶斯工具箱，面向不同的开发者。对于专业的贝叶斯学习用户，提供概率采样、推理算法和模型构建库；另一方面，为不熟悉贝叶斯深度学习的用户提供了高级的API，从而不用更改深度学习编程逻辑，即可利用贝叶斯模型。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 概率分布\n",
    "\n",
    "概率分布（`mindspore.nn.probability.distribution`）是概率编程的基础。`Distribution` 类提供多样的概率统计接口，例如概率密度函数 *pdf* 、累积密度函数 *cdf* 、散度计算 *kl_loss* 、抽样 *sample* 等。现有的概率分布实例包括高斯分布，伯努利分布，指数型分布，几何分布和均匀分布。\n",
    "\n",
    "### 概率分布类\n",
    "\n",
    "- `Distribution`：所有概率分布的基类。\n",
    "\n",
    "- `Bernoulli`：伯努利分布。参数为试验成功的概率。\n",
    "\n",
    "- `Exponential`：指数型分布。参数为率参数。\n",
    "\n",
    "- `Geometric`：几何分布。参数为一次伯努利试验成功的概率。\n",
    "\n",
    "- `Normal`：正态（高斯）分布。参数为均值和标准差。\n",
    "\n",
    "- `Uniform`：均匀分布。参数为数轴上的最小值和最大值。\n",
    "\n",
    "- `Categorical`：类别分布。每种类别出现的概率。\n",
    "\n",
    "- `LogNormal`：对数正态分布。参数为位置参数和规模参数。\n",
    "\n",
    "- `Gumbel`: 耿贝尔极值分布。参数为位置参数和规模参数。\n",
    "\n",
    "- `Logistic`：逻辑斯谛分布。参数为位置参数和规模参数。\n",
    "\n",
    "- `Cauchy`：柯西分布。参数为位置参数和规模参数。\n",
    "\n",
    "#### Distribution基类\n",
    "\n",
    "`Distribution` 是所有概率分布的基类。\n",
    "\n",
    "接口介绍：`Distribution` 类支持的函数包括 `prob`、`log_prob`、`cdf`、`log_cdf`、`survival_function`、`log_survival`、`mean`、`sd`、`var`、`entropy`、`kl_loss`、`cross_entropy` 和 `sample` 。分布不同，所需传入的参数也不同。只有在派生类中才能使用，由派生类的函数实现决定参数。\n",
    "\n",
    "- `prob` ：概率密度函数（PDF）/ 概率质量函数（PMF）。\n",
    "- `log_prob` ：对数似然函数。\n",
    "- `cdf` ：累积分布函数（CDF）。\n",
    "- `log_cdf` ：对数累积分布函数。\n",
    "- `survival_function` ：生存函数。\n",
    "- `log_survival` ：对数生存函数。\n",
    "- `mean` ：均值。\n",
    "- `sd` ：标准差。\n",
    "- `var` ：方差。\n",
    "- `entropy` ：熵。\n",
    "- `kl_loss` ：Kullback-Leibler 散度。\n",
    "- `cross_entropy` ：两个概率分布的交叉熵。\n",
    "- `sample` ：概率分布的随机抽样。\n",
    "- `get_dist_args` ：概率分布在网络中使用的参数。\n",
    "- `get_dist_type` ：概率分布的类型。\n",
    "\n",
    "#### 伯努利分布(Bernoulli)\n",
    "\n",
    "伯努利分布，继承自 `Distribution` 类。\n",
    "\n",
    "属性:\n",
    "\n",
    "- `Bernoulli.probs`：返回伯努利试验成功的概率，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `Bernoulli` 中私有接口以实现基类中的公有接口。`Bernoulli` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`var`：可选择传入 试验成功的概率 *probs1* 。\n",
    "- `entropy`：可选择传入 试验成功的概率 *probs1* 。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* 和 *probs1_b* 。*dist* 为另一分布的类型，目前只支持此处为 *‘Bernoulli’* 。 *probs1_b* 为分布 *b* 的试验成功概率。可选择传入分布 *a* 的参数 *probs1_a* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择传入试验成功的概率 *probs* 。\n",
    "- `sample`：可选择传入样本形状 *shape* 和试验成功的概率 *probs1* 。\n",
    "- `get_dist_args` ：可选择传入试验成功的概率 *probs*。\n",
    "- `get_dist_type` ：返回 *‘Bernoulli’* 。\n",
    "\n",
    "#### 指数分布(Exponential)\n",
    "\n",
    "指数分布，继承自 `Distribution` 类。\n",
    "\n",
    "属性:\n",
    "\n",
    "- `Exponential.rate`：返回分布的率参数，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `Exponential` 私有接口以实现基类中的公有接口。`Exponential` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`var`：可选择传入率参数 *rate* 。\n",
    "- `entropy`：可选择传入率参数 *rate* 。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* 和 *rate_b* 。 *dist* 为另一分布的类型的名称， 目前只支持此处为 *‘Exponential’* 。*rate_b* 为分布 *b* 的率参数。可选择传入分布 *a* 的参数 *rate_a* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择传入率参数 *rate* 。\n",
    "- `sample`：可选择传入样本形状 *shape* 和率参数 *rate* 。\n",
    "- `get_dist_args` ：可选择传入率参数 *rate* 。\n",
    "- `get_dist_type` ：返回 *‘Exponential’* 。\n",
    "\n",
    "#### 几何分布(Geometric)\n",
    "\n",
    "几何分布，继承自 `Distribution` 类。\n",
    "\n",
    "属性:\n",
    "\n",
    "- `Geometric.probs`：返回伯努利试验成功的概率，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `Geometric` 中私有接口以实现基类中的公有接口。`Geometric` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`var`：可选择传入试验成功的概率 *probs1* 。\n",
    "- `entropy`：可选择传入 试验成功的概率 *probs1* 。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* 和 *probs1_b* 。*dist* 为另一分布的类型的名称，目前只支持此处为 *‘Geometric’* 。 *probs1_b* 为分布 *b* 的试验成功概率。可选择传入分布 *a* 的参数 *probs1_a* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择传入试验成功的概率 *probs1* 。\n",
    "- `sample`：可选择传入样本形状 *shape* 和试验成功的概率 *probs1* 。\n",
    "- `get_dist_args` ：可选择传入试验成功的概率 *probs1* 。\n",
    "- `get_dist_type` ：返回 *‘Geometric’* 。\n",
    "\n",
    "#### 正态分布(Normal)\n",
    "\n",
    "正态（高斯）分布，继承自 `Distribution` 类。\n",
    "\n",
    "`Distribution` 基类调用 `Normal` 中私有接口以实现基类中的公有接口。`Normal` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`var`：可选择传入分布的参数均值 *mean* 和标准差 *sd* 。\n",
    "- `entropy`：可选择传入分布的参数均值 *mean* 和标准差 *sd* 。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* ，*mean_b* 和 *sd_b* 。*dist* 为另一分布的类型的名称，目前只支持此处为 *‘Normal’* 。*mean_b* 和 *sd_b* 为分布 *b* 的均值和标准差。可选择传入分布的参数 *a* 均值 *mean_a* 和标准差 *sd_a* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择分布的参数包括均值 *mean_a* 和标准差 *sd_a* 。\n",
    "- `sample`：可选择传入样本形状 *shape* 和分布的参数包括均值 *mean_a* 和标准差 *sd_a* 。\n",
    "- `get_dist_args` ：可选择传入分布的参数均值 *mean* 和标准差 *sd* 。\n",
    "- `get_dist_type` ：返回 *‘Normal’* 。\n",
    "\n",
    "#### 均匀分布(Uniform)\n",
    "\n",
    "均匀分布，继承自 `Distribution` 类。\n",
    "\n",
    "属性:\n",
    "\n",
    "- `Uniform.low`：返回分布的最小值，类型为`Tensor`。\n",
    "- `Uniform.high`：返回分布的最大值，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `Uniform` 以实现基类中的公有接口。`Uniform` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`var`：可选择传入分布的参数最大值 *high* 和最小值 *low* 。\n",
    "- `entropy`：可选择传入分布的参数最大值 *high* 和最小值 *low* 。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* ，*high_b* 和 *low_b* 。*dist* 为另一分布的类型的名称，目前只支持此处为 *‘Uniform’* 。 *high_b* 和 *low_b* 为分布 *b* 的参数。可选择传入分布 *a* 的参数即最大值 *high_a* 和最小值 *low_a* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择传入分布的参数最大值 *high* 和最小值 *low* 。\n",
    "- `sample`：可选择传入 *shape* 和分布的参数即最大值 *high* 和最小值 *low* 。\n",
    "- `get_dist_args` ：可选择传入分布的参数最大值 *high* 和最小值 *low* 。\n",
    "- `get_dist_type` ：返回 *‘Uniform’* 。\n",
    "\n",
    "#### 多类别分布（Categorical）\n",
    "\n",
    "多类别分布，继承自 `Distribution` 类。\n",
    "\n",
    "属性:\n",
    "\n",
    "- `Categorical.probs`：返回各种类别的概率，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `Categorical` 以实现基类中的公有接口。`Categorical` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`var`：可选择传入分布的参数类别概率 *probs*。\n",
    "- `entropy`：可选择传入分布的参数类别概率 *probs* 。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* ，*probs_b* 。*dist* 为另一分布的类型的名称，目前只支持此处为 *‘Categorical’* 。 *probs_b* 为分布 *b* 的参数。可选择传入分布 *a* 的参数即 *probs_a* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择传入分布的参数类别概率 *probs* 。\n",
    "- `sample`：可选择传入 *shape* 和类别概率 *probs* 。\n",
    "- `get_dist_args` ：可选择传入分布的参数类别概率 *probs* 。\n",
    "- `get_dist_type` ：返回 *‘Categorical’* 。\n",
    "\n",
    "#### 对数正态分布(LogNormal)\n",
    "\n",
    "对数正态分布，继承自 `TransformedDistribution` 类，由 `Exp` Bijector 和 `Normal` Distribution 构成。\n",
    "\n",
    "属性：\n",
    "\n",
    "- `LogNormal.loc`：返回分布的位置参数，类型为`Tensor`。\n",
    "- `LogNormal.scale`：返回分布的规模参数，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `LogNormal`及 `TransformedDistribution` 中私有接口以实现基类中的公有接口。`LogNormal` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`var`：可选择传入分布的位置参数*loc*和规模参数*scale* 。\n",
    "- `entropy`：可选择传入分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* ，*loc_b* 和 *scale_b* 。*dist* 为另一分布的类型的名称，目前只支持此处为 *‘LogNormal’* 。*loc_b* 和 *scale_b* 为分布 *b* 的均值和标准差。可选择传入分布的参数 *a* 均值 *loc_a* 和标准差 *sclae_a* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择分布的参数包括均值 *loc_a* 和标准差 *scale_a* 。`Distribution` 基类调用 `TransformedDistribution`私有接口。\n",
    "- `sample`：可选择传入样本形状 *shape* 和分布的参数包括均值 *loc_a* 和标准差 *scale_a* 。`Distribution` 基类调用 `TransformedDistribution`私有接口。\n",
    "- `get_dist_args` ：可选择传入分布的位置参数 *loc* 和规模参数*scale* 。\n",
    "- `get_dist_type` ：返回 *‘LogNormal’* 。\n",
    "\n",
    "#### 柯西分布(Cauchy)\n",
    "\n",
    "柯西分布，继承自 `Distribution` 类。\n",
    "\n",
    "属性：\n",
    "\n",
    "- `Cauchy.loc`：返回分布的位置参数，类型为`Tensor`。\n",
    "- `Cauchy.scale`：返回分布的规模参数，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `Cauchy` 中私有接口以实现基类中的公有接口。`Cauchy` 支持的公有接口为：\n",
    "\n",
    "- `entropy`：可选择传入分布的位置参数*loc*和规模参数*scale*。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* ，*loc_b* 和 *scale_b* 。*dist* 为另一分布的类型的名称，目前只支持此处为 *‘Cauchy’* 。*loc_b* 和 *scale_b* 为分布 *b* 的位置参数和规模参数。可选择传入分布的参数 *a* 位置 *loc_a* 和规模 *scale_a* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择传入分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `sample`：可选择传入样本形状 *shape* 和分布的参数包括分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `get_dist_args` ：可选择传入分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `get_dist_type` ：返回 *‘Cauchy’* 。\n",
    "\n",
    "#### 耿贝尔极值分布(Gumbel)\n",
    "\n",
    "耿贝尔极值分布，继承自 `TransformedDistribution` 类，由 `GumbelCDF` Bijector和 `Uniform` Distribution 构成。\n",
    "\n",
    "属性：\n",
    "\n",
    "- `Gumbel.loc`：返回分布的位置参数，类型为`Tensor`。\n",
    "- `Gumbel.scale`：返回分布的规模参数，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `Gumbel` 中私有接口以实现基类中的公有接口。`Gumbel` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`sd`：无参数 。\n",
    "- `entropy`：无参数 。\n",
    "- `cross_entropy`，`kl_loss`：必须传入 *dist* ，*loc_b* 和 *scale_b* 。*dist* 为另一分布的类型的名称，目前只支持此处为 *‘Gumbel’* 。*loc_b* 和 *scale_b* 为分布 *b* 的位置参数和规模参数。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。\n",
    "- `sample`：可选择传入样本形状 *shape* 。\n",
    "- `get_dist_args` ：可选择传入分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `get_dist_type` ：返回 *‘Gumbel’* 。\n",
    "\n",
    "#### 逻辑斯谛分布(Logistic)\n",
    "\n",
    "逻辑斯谛分布，继承自 `Distribution` 类。\n",
    "\n",
    "属性：\n",
    "\n",
    "- `Logistic.loc`：返回分布的位置参数，类型为`Tensor`。\n",
    "- `Logistic.scale`：返回分布的规模参数，类型为`Tensor`。\n",
    "\n",
    "`Distribution` 基类调用 `logistic` 中私有接口以实现基类中的公有接口。`Logistic` 支持的公有接口为：\n",
    "\n",
    "- `mean`，`mode`，`sd`：可选择传入分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `entropy`：可选择传入分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `prob`，`log_prob`，`cdf`，`log_cdf`，`survival_function`，`log_survival`：必须传入 *value* 。可选择传入分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `sample`：可选择传入样本形状 *shape* 和分布的参数包括分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `get_dist_args` ：可选择传入分布的位置参数 *loc* 和规模参数 *scale* 。\n",
    "- `get_dist_type` ：返回 *‘Logistic’* 。"
   ]
  },
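  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the interfaces listed above, the sketch below creates a `Bernoulli` distribution and queries a few of its public interfaces. This is a minimal sketch, assuming a MindSpore installation with PyNative mode available; the printed tensor formatting may differ between versions.\n",
    "\n",
    "```python\n",
    "import mindspore.context as context\n",
    "import mindspore.nn.probability.distribution as msd\n",
    "from mindspore import Tensor\n",
    "from mindspore import dtype as mstype\n",
    "\n",
    "context.set_context(mode=context.PYNATIVE_MODE)\n",
    "\n",
    "# Bernoulli distribution with success probability p = 0.5\n",
    "b = msd.Bernoulli(0.5, dtype=mstype.int32)\n",
    "\n",
    "print('mean: ', b.mean())   # mathematically p = 0.5\n",
    "print('var: ', b.var())     # mathematically p * (1 - p) = 0.25\n",
    "# PMF at the support points 0 and 1; both equal 0.5 here\n",
    "print('prob: ', b.prob(Tensor([0.0, 1.0], dtype=mstype.float32)))\n",
    "print('type: ', b.get_dist_type())\n",
    "```"
   ]
  },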
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 概率分布类在PyNative模式下的应用\n",
    "\n",
    "`Distribution` 子类可在 **PyNative** 模式下使用。\n",
    "\n",
    "以 `Normal` 为例， 创建一个均值为0.0、标准差为1.0的正态分布，然后计算相关函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean:  0.0\n",
      "var:  1.0\n",
      "entropy:  1.4189385\n",
      "prob:  [0.35206532 0.3989423  0.35206532]\n",
      "cdf:  [0.30853754 0.5        0.69146246]\n",
      "kl:  0.44314718\n"
     ]
    }
   ],
   "source": [
    "from mindspore import Tensor\n",
    "from mindspore import dtype as mstype\n",
    "import mindspore.context as context\n",
    "import mindspore.nn.probability.distribution as msd\n",
    "context.set_context(mode=context.PYNATIVE_MODE, device_target=\"GPU\")\n",
    "\n",
    "my_normal = msd.Normal(0.0, 1.0, dtype=mstype.float32)\n",
    "\n",
    "mean = my_normal.mean()\n",
    "var = my_normal.var()\n",
    "entropy = my_normal.entropy()\n",
    "\n",
    "value = Tensor([-0.5, 0.0, 0.5], dtype=mstype.float32)\n",
    "prob = my_normal.prob(value)\n",
    "cdf = my_normal.cdf(value)\n",
    "\n",
    "mean_b = Tensor(1.0, dtype=mstype.float32)\n",
    "sd_b = Tensor(2.0, dtype=mstype.float32)\n",
    "kl = my_normal.kl_loss('Normal', mean_b, sd_b)\n",
    "\n",
    "print(\"mean: \", mean)\n",
    "print(\"var: \", var)\n",
    "print(\"entropy: \", entropy)\n",
    "print(\"prob: \", prob)\n",
    "print(\"cdf: \", cdf)\n",
    "print(\"kl: \", kl)"
   ]
  },
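  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper interfaces `survival_function`, `log_survival`, `sample`, `get_dist_args`, and `get_dist_type` can be exercised the same way. A minimal sketch, reusing `my_normal` from the cell above (outputs not shown here; `sample` results depend on the random seed):\n",
    "\n",
    "```python\n",
    "value = Tensor([-0.5, 0.0, 0.5], dtype=mstype.float32)\n",
    "\n",
    "sf = my_normal.survival_function(value)   # equals 1 - cdf(value)\n",
    "log_sf = my_normal.log_survival(value)\n",
    "samples = my_normal.sample((2, 3))        # Tensor of shape (2, 3)\n",
    "\n",
    "print('survival: ', sf)\n",
    "print('log_survival: ', log_sf)\n",
    "print('sample shape: ', samples.shape)\n",
    "print('dist args: ', my_normal.get_dist_args())\n",
    "print('dist type: ', my_normal.get_dist_type())\n",
    "```"
   ]
  },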
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 概率分布类在图模式下的应用\n",
    "\n",
    "在图模式下，`Distribution` 子类可用在网络中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pdf:  [0.3520653  0.39894226 0.3520653 ]\n",
      "kl:  0.5\n"
     ]
    }
   ],
   "source": [
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor\n",
    "from mindspore import dtype as mstype\n",
    "import mindspore.context as context\n",
    "import mindspore.nn.probability.distribution as msd\n",
    "context.set_context(mode=context.GRAPH_MODE)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.normal = msd.Normal(0.0, 1.0, dtype=mstype.float32)\n",
    "\n",
    "    def construct(self, value, mean, sd):\n",
    "        pdf = self.normal.prob(value)\n",
    "        kl = self.normal.kl_loss(\"Normal\", mean, sd)\n",
    "        return pdf, kl\n",
    "\n",
    "net = Net()\n",
    "value = Tensor([-0.5, 0.0, 0.5], dtype=mstype.float32)\n",
    "mean = Tensor(1.0, dtype=mstype.float32)\n",
    "sd = Tensor(1.0, dtype=mstype.float32)\n",
    "pdf, kl = net(value, mean, sd)\n",
    "print(\"pdf: \", pdf)\n",
    "print(\"kl: \", kl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TransformedDistribution类接口设计\n",
    "\n",
    "`TransformedDistribution` 继承自 `Distribution` ，是可通过映射f(x)变化得到的数学分布的基类。其接口包括：\n",
    "\n",
    "1. 属性\n",
    "\n",
    "    - `bijector`：返回分布的变换方法。\n",
    "    - `distribution`：返回原始分布。\n",
    "    - `is_linear_transformation`：返回线性变换标志。\n",
    "\n",
    "2. 接口函数（以下接口函数的参数与构造函数中 `distribution` 的对应接口的参数相同）。\n",
    "\n",
    "    - `cdf`：累积分布函数（CDF）。\n",
    "    - `log_cdf`：对数累积分布函数。\n",
    "    - `survival_function`：生存函数。\n",
    "    - `log_survival`：对数生存函数。\n",
    "    - `prob`：概率密度函数（PDF）/ 概率质量函数（PMF）。\n",
    "    - `log_prob`：对数似然函数。\n",
    "    - `sample`：随机取样。\n",
    "    - `mean`：无参数。只有当 `Bijector.is_constant_jacobian=true` 时可调用。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyNative模式下调用TransformedDistribution实例\n",
    "\n",
    "`TransformedDistribution` 子类可在 **PyNative** 模式下使用。\n",
    "\n",
    "这里构造一个 `TransformedDistribution` 实例，使用 `Normal` 分布作为需要变换的分布类，使用 `Exp` 作为映射变换，可以生成 `LogNormal` 分布。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TransformedDistribution<\n",
      "  (_bijector): Exp<power = 0.0>\n",
      "  (_distribution): Normal<mean = 0.0, standard deviation = 1.0>\n",
      "  >\n",
      "underlying distribution:\n",
      " Normal<mean = 0.0, standard deviation = 1.0>\n",
      "bijector:\n",
      " Exp<power = 0.0>\n",
      "cdf:\n",
      " [0.7558914 0.9462397 0.9893489]\n",
      "sample:\n",
      " [[2.5509434  0.6857236 ]\n",
      " [0.70752895 0.5937665 ]\n",
      " [0.94011104 0.86582065]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "import mindspore.nn.probability.bijector as msb\n",
    "import mindspore.nn.probability.distribution as msd\n",
    "import mindspore.context as context\n",
    "from mindspore import Tensor, dtype\n",
    "\n",
    "context.set_context(mode=context.PYNATIVE_MODE)\n",
    "\n",
    "normal = msd.Normal(0.0, 1.0, dtype=dtype.float32)\n",
    "exp = msb.Exp()\n",
    "LogNormal = msd.TransformedDistribution(exp, normal, seed=0, name=\"LogNormal\")\n",
    "\n",
    "# compute cumulative distribution function\n",
    "x = np.array([2.0, 5.0, 10.0], dtype=np.float32)\n",
    "tx = Tensor(x, dtype=dtype.float32)\n",
    "cdf = LogNormal.cdf(tx)\n",
    "\n",
    "# generate samples from the distribution\n",
    "shape = ((3, 2))\n",
    "sample = LogNormal.sample(shape)\n",
    "\n",
    "# get information of the distribution\n",
    "print(LogNormal)\n",
    "# get information of the underlying distribution and the bijector separately\n",
    "print(\"underlying distribution:\\n\", LogNormal.distribution)\n",
    "print(\"bijector:\\n\", LogNormal.bijector)\n",
    "# get the computation results\n",
    "print(\"cdf:\\n\", cdf)\n",
    "print(\"sample:\\n\", sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当构造 `TransformedDistribution` 映射变换的 `is_constant_jacobian = true` 时（如 `ScalarAffine`)，构造的 `TransformedDistribution` 实例可以使用直接使用 `mean` 接口计算均值，例如："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0\n"
     ]
    }
   ],
   "source": [
    "normal = msd.Normal(0.0, 1.0, dtype=dtype.float32)\n",
    "scalaraffine = msb.ScalarAffine(1.0, 2.0)\n",
    "trans_dist = msd.TransformedDistribution(scalaraffine, normal, seed=0)\n",
    "mean = trans_dist.mean()\n",
    "print(mean)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 图模式下调用TransformedDistribution实例\n",
    "\n",
    "在图模式下，`TransformedDistribution` 类可用在网络中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cdf:  [0.7558914  0.86403143 0.9171715  0.9462397 ]\n",
      "sample:  [[1.9951994 0.4296873 1.6143396]\n",
      " [1.1043785 5.220887  1.504073 ]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor, dtype\n",
    "import mindspore.context as context\n",
    "import mindspore.nn.probability.bijector as msb\n",
    "import mindspore.nn.probability.distribution as msd\n",
    "context.set_context(mode=context.GRAPH_MODE)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self, shape, dtype=dtype.float32, seed=0, name='transformed_distribution'):\n",
    "        super(Net, self).__init__()\n",
    "        # create TransformedDistribution distribution\n",
    "        self.exp = msb.Exp()\n",
    "        self.normal = msd.Normal(0.0, 1.0, dtype=dtype)\n",
    "        self.lognormal = msd.TransformedDistribution(self.exp, self.normal, seed=seed, name=name)\n",
    "        self.shape = shape\n",
    "\n",
    "    def construct(self, value):\n",
    "        cdf = self.lognormal.cdf(value)\n",
    "        sample = self.lognormal.sample(self.shape)\n",
    "        return cdf, sample\n",
    "\n",
    "shape = (2, 3)\n",
    "net = Net(shape=shape, name=\"LogNormal\")\n",
    "x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)\n",
    "tx = Tensor(x, dtype=dtype.float32)\n",
    "cdf, sample = net(tx)\n",
    "print(\"cdf: \", cdf)\n",
    "print(\"sample: \", sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 概率分布映射\n",
    "\n",
    "Bijector（`mindspore.nn.probability.bijector`）是概率编程的基本组成部分。Bijector描述了一种随机变量的变换方法，可以通过一个已有的随机变量X和一个映射函数f生成一个新的随机变量$Y = f(x)$。\n",
    "`Bijector` 提供了映射相关的四种变换方法。它可以当做算子直接使用，也可以作用在某个随机变量 `Distribution` 类实例上生成新的随机变量的 `Distribution` 类实例。\n",
    "\n",
    "### Bijector类接口设计\n",
    "\n",
    "#### Bijector基类\n",
    "\n",
    "`Bijector` 类是所有概率分布映射的基类。其接口包括：\n",
    "\n",
    "1. 属性\n",
    "    - `name`：返回 `name` 的值。\n",
    "    - `is_dtype`：返回 `dtype` 的值。\n",
    "    - `parameter`：返回 `parameter` 的值。\n",
    "    - `is_constant_jacobian`：返回 `is_constant_jacobian` 的值。\n",
    "    - `is_injective`：返回 `is_injective` 的值。\n",
    "\n",
    "2. 映射函数\n",
    "    - `forward`：正向映射，创建派生类后由派生类的 `_forward` 决定参数。\n",
    "    - `inverse`：反向映射，创建派生类后由派生类的 `_inverse` 决定参数。\n",
    "    - `forward_log_jacobian`：正向映射的导数的对数，创建派生类后由派生类的 `_forward_log_jacobian` 决定参数。\n",
    "    - `inverse_log_jacobian`：反向映射的导数的对数，创建派生类后由派生类的 `_inverse_log_jacobian` 决定参数。\n",
    "\n",
    "`Bijector` 作为函数调用：输入是一个 `Distribution` 类：生成一个 `TransformedDistribution` **（不可在图内调用）**。\n",
    "\n",
    "#### 幂函数变换映射(PowerTransform)\n",
    "\n",
    "`PowerTransform` 做如下变量替换：`Y = g(X) = {(1 + X * power)}^{1 / power}`。其接口包括：\n",
    "\n",
    "1. 属性\n",
    "    - `power`：返回 `power` 的值，类型为`Tensor`。\n",
    "\n",
    "2. 映射函数\n",
    "    - `forward`：正向映射，输入为 `Tensor` 。\n",
    "    - `inverse`：反向映射，输入为 `Tensor` 。\n",
    "    - `forward_log_jacobian`：正向映射的导数的对数，输入为 `Tensor` 。\n",
    "    - `inverse_log_jacobian`：反向映射的导数的对数，输入为 `Tensor` 。\n",
    "\n",
    "#### 指数变换映射(Exp)\n",
    "\n",
    "`Exp` 做如下变量替换：`Y = g(X)= exp(X)`。其接口包括：\n",
    "\n",
    "映射函数\n",
    "\n",
    "- `forward`：正向映射，输入为 `Tensor` 。\n",
    "- `inverse`：反向映射，输入为 `Tensor` 。\n",
    "- `forward_log_jacobian`：正向映射的导数的对数，输入为 `Tensor` 。\n",
    "- `inverse_log_jacobian`：反向映射的导数的对数，输入为 `Tensor` 。\n",
    "\n",
    "#### 标量仿射变换映射(ScalarAffine)\n",
    "\n",
    "`ScalarAffine` 做如下变量替换：`Y = g(X) = scale * X + shift`。其接口包括：\n",
    "\n",
    "1. 属性\n",
    "    - `scale`：返回`scale`的值，类型为`Tensor`。\n",
    "    - `shift`：返回`shift`的值，类型为`Tensor`。\n",
    "\n",
    "2. 映射函数\n",
    "    - `forward`：正向映射，输入为 `Tensor` 。\n",
    "    - `inverse`：反向映射，输入为 `Tensor` 。\n",
    "    - `forward_log_jacobian`：正向映射的导数的对数，输入为 `Tensor` 。\n",
    "    - `inverse_log_jacobian`：反向映射的导数的对数，输入为 `Tensor` 。\n",
    "\n",
    "#### Softplus变换映射(Softplus)\n",
    "\n",
    "`Softplus` 做如下变量替换：`Y = g(X) = log(1 + e ^ {sharpness * X}) / sharpness`。其接口包括：\n",
    "\n",
    "1. 属性\n",
    "    - `sharpness`：返回 `sharpness` 的值，类型为`Tensor`。\n",
    "\n",
    "2. 映射函数\n",
    "    - `forward`：正向映射，输入为 `Tensor` 。\n",
    "    - `inverse`：反向映射，输入为 `Tensor` 。\n",
    "    - `forward_log_jacobian`：正向映射的导数的对数，输入为 `Tensor` 。\n",
    "    - `inverse_log_jacobian`：反向映射的导数的对数，输入为 `Tensor` 。\n",
    "\n",
    "#### 耿贝尔累计密度函数映射(GumbelCDF)\n",
    "\n",
    "`GumbelCDF` 做如下变量替换：$Y = g(X) = \\exp(-\\exp(-\\frac{X - loc}{scale}))$。其接口包括：\n",
    "\n",
    "1. 属性\n",
    "    - `loc`：返回`loc`的值，类型为`Tensor`。\n",
    "    - `scale`：返回`scale`的值，类型为`Tensor`。\n",
    "\n",
    "2. 映射函数\n",
    "    - `forward`：正向映射，输入为 `Tensor` 。\n",
    "    - `inverse`：反向映射，输入为 `Tensor` 。\n",
    "    - `forward_log_jacobian`：正向映射的导数的对数，输入为 `Tensor` 。\n",
    "    - `inverse_log_jacobian`：反向映射的导数的对数，输入为 `Tensor` 。\n",
    "\n",
    "#### 逆映射(Invert)\n",
    "\n",
    "`Invert` 对一个映射做逆变换，其接口包括：\n",
    "\n",
    "1. 属性\n",
    "    - `bijector`：返回初始化时使用的*Bijector*，类型为`Bijector`。\n",
    "\n",
    "2. 映射函数\n",
    "    - `forward`：正向映射，输入为 `Tensor` 。\n",
    "    - `inverse`：反向映射，输入为 `Tensor` 。\n",
    "    - `forward_log_jacobian`：正向映射的导数的对数，输入为 `Tensor` 。\n",
    "    - `inverse_log_jacobian`：反向映射的导数的对数，输入为 `Tensor` 。"
   ]
  },
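  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Invert` bijector has no dedicated example in the following sections. As a minimal sketch (assuming a MindSpore installation with PyNative mode available), wrapping `Exp` in `Invert` swaps its forward and inverse mappings, so `forward` computes the natural logarithm:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import mindspore.context as context\n",
    "import mindspore.nn.probability.bijector as msb\n",
    "from mindspore import Tensor, dtype\n",
    "\n",
    "context.set_context(mode=context.PYNATIVE_MODE)\n",
    "\n",
    "# Invert(Exp()): forward is log(x), inverse is exp(x)\n",
    "inv_exp = msb.Invert(msb.Exp())\n",
    "\n",
    "tx = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32), dtype=dtype.float32)\n",
    "print('forward: ', inv_exp.forward(tx))   # log of each element\n",
    "print('inverse: ', inv_exp.inverse(tx))   # exp of each element\n",
    "```"
   ]
  },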
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyNative模式下调用Bijector实例\n",
    "\n",
    "在执行之前，我们需要导入需要的库文件包。双射类最主要的库是 `mindspore.nn.probability.bijector`，导入后我们使用 `msb` 作为库的缩写并进行调用。\n",
    "\n",
    "下面我们以 `PowerTransform` 为例。创建一个指数为2的 `PowerTransform` 对象。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PowerTransform<power = 2.0>\n",
      "forward:  [2.236068  2.6457515 3.        3.3166249]\n",
      "inverse:  [ 1.5       4.        7.5      12.000001]\n",
      "forward_log_jacobian:  [-0.804719  -0.9729551 -1.0986123 -1.1989477]\n",
      "inverse_log_jacobian:  [0.6931472 1.0986123 1.3862944 1.609438 ]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "import mindspore.nn.probability.bijector as msb\n",
    "import mindspore.context as context\n",
    "from mindspore import Tensor, dtype\n",
    "\n",
    "context.set_context(mode=context.PYNATIVE_MODE)\n",
    "\n",
    "powertransform = msb.PowerTransform(power=2.)\n",
    "\n",
    "x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)\n",
    "tx = Tensor(x, dtype=dtype.float32)\n",
    "forward = powertransform.forward(tx)\n",
    "inverse = powertransform.inverse(tx)\n",
    "forward_log_jaco = powertransform.forward_log_jacobian(tx)\n",
    "inverse_log_jaco = powertransform.inverse_log_jacobian(tx)\n",
    "\n",
    "print(powertransform)\n",
    "print(\"forward: \", forward)\n",
    "print(\"inverse: \", inverse)\n",
    "print(\"forward_log_jacobian: \", forward_log_jaco)\n",
    "print(\"inverse_log_jacobian: \", inverse_log_jaco)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 图模式下调用Bijector实例\n",
    "\n",
    "在图模式下，`Bijector` 子类可用在网络中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward:  [2.236068  2.6457515 3.        3.3166249]\n",
      "inverse:  [ 1.5       4.        7.5      12.000001]\n",
      "forward_log_jacobian:  [-0.804719  -0.9729551 -1.0986123 -1.1989477]\n",
      "inverse_log_jacobian:  [0.6931472 1.0986123 1.3862944 1.609438 ]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor\n",
    "from mindspore import dtype as mstype\n",
    "import mindspore.context as context\n",
    "import mindspore.nn.probability.bijector as msb\n",
    "context.set_context(mode=context.GRAPH_MODE)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # create a PowerTransform bijector\n",
    "        self.powertransform = msb.PowerTransform(power=2.)\n",
    "\n",
    "    def construct(self, value):\n",
    "        forward = self.powertransform.forward(value)\n",
    "        inverse = self.powertransform.inverse(value)\n",
    "        forward_log_jaco = self.powertransform.forward_log_jacobian(value)\n",
    "        inverse_log_jaco = self.powertransform.inverse_log_jacobian(value)\n",
    "        return forward, inverse, forward_log_jaco, inverse_log_jaco\n",
    "\n",
    "net = Net()\n",
    "x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)\n",
    "tx = Tensor(x, dtype=mstype.float32)\n",
    "forward, inverse, forward_log_jaco, inverse_log_jaco = net(tx)\n",
    "print(\"forward: \", forward)\n",
    "print(\"inverse: \", inverse)\n",
    "print(\"forward_log_jacobian: \", forward_log_jaco)\n",
    "print(\"inverse_log_jacobian: \", inverse_log_jaco)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 深度概率网络\n",
    "\n",
    "使用MindSpore深度概率编程库（`mindspore.nn.probability.dpn`）来构造变分自编码器（VAE）进行推理尤为简单。我们只需要自定义编码器和解码器（DNN模型），调用VAE或CVAE接口形成其派生网络，然后调用ELBO接口进行优化，最后使用SVI接口进行变分推理。这样做的好处是，不熟悉变分推理的用户可以像构建DNN模型一样来构建概率模型，而熟悉的用户可以调用这些接口来构建更为复杂的概率模型。VAE的接口在`mindspore.nn.probability.dpn`下面，dpn代表的是Deep probabilistic network，这里提供了一些基本的深度概率网络的接口，例如VAE。\n",
    "\n",
    "### VAE\n",
    "\n",
    "首先，我们需要先自定义encoder和decoder，调用`mindspore.nn.probability.dpn.VAE`接口来构建VAE网络，我们除了传入encoder和decoder之外，还需要传入encoder输出变量的维度hidden size，以及VAE网络存储潜在变量的维度latent size，一般latent size会小于hidden size。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "from mindspore.nn.probability.dpn import VAE\n",
    "\n",
    "IMAGE_SHAPE = (-1, 1, 32, 32)\n",
    "\n",
    "\n",
    "class Encoder(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.fc1 = nn.Dense(1024, 800)\n",
    "        self.fc2 = nn.Dense(800, 400)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class Decoder(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.fc1 = nn.Dense(400, 1024)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "        self.reshape = ops.Reshape()\n",
    "\n",
    "    def construct(self, z):\n",
    "        z = self.fc1(z)\n",
    "        z = self.reshape(z, IMAGE_SHAPE)\n",
    "        z = self.sigmoid(z)\n",
    "        return z\n",
    "\n",
    "\n",
    "encoder = Encoder()\n",
    "decoder = Decoder()\n",
    "vae = VAE(encoder, decoder, hidden_size=400, latent_size=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ConditionalVAE\n",
    "\n",
    "类似地，ConditionalVAE与VAE的使用方法比较相近，不同的是，ConditionalVAE利用了数据集的标签信息，属于有监督学习算法，其生成效果一般会比VAE好。\n",
    "\n",
    "首先，先自定义encoder和decoder，并调用`mindspore.nn.probability.dpn.ConditionalVAE`接口来构建ConditionalVAE网络，这里的encoder和VAE的不同，因为需要传入数据集的标签信息；decoder和上述的一样。ConditionalVAE接口的传入则还需要传入数据集的标签类别个数，其余和VAE接口一样。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "from mindspore.nn.probability.dpn import ConditionalVAE\n",
    "\n",
    "IMAGE_SHAPE = (-1, 1, 32, 32)\n",
    "\n",
    "\n",
    "class Encoder(nn.Cell):\n",
    "    def __init__(self, num_classes):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.fc1 = nn.Dense(1024 + num_classes, 400)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.concat = ops.Concat(axis=1)\n",
    "        self.one_hot = nn.OneHot(depth=num_classes)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = self.flatten(x)\n",
    "        y = self.one_hot(y)\n",
    "        input_x = self.concat((x, y))\n",
    "        input_x = self.fc1(input_x)\n",
    "        input_x = self.relu(input_x)\n",
    "        return input_x\n",
    "\n",
    "\n",
    "class Decoder(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.fc1 = nn.Dense(400, 1024)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "        self.reshape = ops.Reshape()\n",
    "\n",
    "    def construct(self, z):\n",
    "        z = self.fc1(z)\n",
    "        z = self.reshape(z, IMAGE_SHAPE)\n",
    "        z = self.sigmoid(z)\n",
    "        return z\n",
    "\n",
    "\n",
    "encoder = Encoder(num_classes=10)\n",
    "decoder = Decoder()\n",
    "cvae = ConditionalVAE(encoder, decoder, hidden_size=400, latent_size=20, num_classes=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载数据集，我们可以使用Mnist数据集，具体的数据加载和预处理过程可以参考这里[实现一个图片分类应用](https://www.mindspore.cn/tutorial/training/zh-CN/master/quick_start/quick_start.html)，这里会用到create_dataset函数创建数据迭代器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.dataset as ds\n",
    "from mindspore.common.initializer import Normal\n",
    "from mindspore.dataset.vision import Inter\n",
    "from mindspore import dtype as mstype\n",
    "import mindspore.dataset.vision.c_transforms as CV\n",
    "import mindspore.dataset.transforms.c_transforms as C\n",
    "\n",
    "def create_dataset(data_path, batch_size=32, repeat_size=1,\n",
    "                   num_parallel_workers=1):\n",
    "    \"\"\" create dataset for train or test\n",
    "    Args:\n",
    "        data_path: Data path\n",
    "        batch_size: The number of data records in each group\n",
    "        repeat_size: The number of replicated data records\n",
    "        num_parallel_workers: The number of parallel workers\n",
    "    \"\"\"\n",
    "    # define dataset\n",
    "    mnist_ds = ds.MnistDataset(data_path)\n",
    "\n",
    "    # define operation parameters\n",
    "    resize_height, resize_width = 32, 32\n",
    "    rescale = 1.0 / 255.0\n",
    "    shift = 0.0\n",
    "    rescale_nml = 1 / 0.3081\n",
    "    shift_nml = -1 * 0.1307 / 0.3081\n",
    "\n",
    "    # define map operations\n",
    "    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Resize images to (32, 32)\n",
    "    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml) # normalize images\n",
    "    rescale_op = CV.Rescale(rescale, shift) # rescale images\n",
    "    hwc2chw_op = CV.HWC2CHW() # change shape from (height, width, channel) to (channel, height, width) to fit network.\n",
    "    type_cast_op = C.TypeCast(mstype.int32) # change data type of label to int32 to fit network\n",
    "\n",
    "    # apply map operations on images\n",
    "    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns=\"label\", num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(operations=resize_op, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "\n",
    "    # apply DatasetOps\n",
    "    buffer_size = 10000\n",
    "    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script\n",
    "    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
    "    mnist_ds = mnist_ds.repeat(repeat_size)\n",
    "\n",
    "    return mnist_ds\n",
    "\n",
    "image_path = \"./MNIST/train\"\n",
    "ds_train = create_dataset(image_path, 128, 1)"
   ]
  },
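  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two chained `Rescale` operations above implement the standard MNIST normalization `(pixel / 255 - mean) / std` with the commonly used statistics `mean = 0.1307` and `std = 0.3081`. A minimal pure-Python check of this arithmetic (no MindSpore required; the helper names are illustrative):\n",
    "\n",
    "```python\n",
    "rescale, shift = 1.0 / 255.0, 0.0\n",
    "rescale_nml, shift_nml = 1 / 0.3081, -1 * 0.1307 / 0.3081\n",
    "\n",
    "def chained_rescale(pixel):\n",
    "    # apply Rescale(rescale, shift), then Rescale(rescale_nml, shift_nml)\n",
    "    x = pixel * rescale + shift\n",
    "    return x * rescale_nml + shift_nml\n",
    "\n",
    "def standard_normalize(pixel, mean=0.1307, std=0.3081):\n",
    "    return (pixel / 255.0 - mean) / std\n",
    "\n",
    "# the two formulations agree on every pixel value\n",
    "for p in (0, 31, 255):\n",
    "    assert abs(chained_rescale(p) - standard_normalize(p)) < 1e-9\n",
    "```"
   ]
  },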
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, use the infer interface to perform variational inference on the VAE network.\n",
    "\n",
    "## Probabilistic Inference Algorithms\n",
    "\n",
    "Call the ELBO interface (`mindspore.nn.probability.infer.ELBO`) to define the loss function of the VAE network, wrap the VAE network and the loss function with `WithLossCell`, define the optimizer, and then pass them to the SVI interface (`mindspore.nn.probability.infer.SVI`). The `run` function of SVI trains the VAE network; you can specify the number of training `epochs`, and it returns the trained network. The `get_train_loss` function returns the loss of the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.nn.probability.infer import ELBO, SVI\n",
    "\n",
    "net_loss = ELBO(latent_prior='Normal', output_prior='Normal')\n",
    "net_with_loss = nn.WithLossCell(vae, net_loss)\n",
    "optimizer = nn.Adam(params=vae.trainable_params(), learning_rate=0.001)\n",
    "\n",
    "vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)\n",
    "vae = vi.run(train_dataset=ds_train, epochs=10)\n",
    "trained_loss = vi.get_train_loss()"
   ]
  },
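  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The negative ELBO minimized above is the reconstruction loss plus a KL term; with the `'Normal'` prior and posterior used here, the KL term for each latent dimension has a closed form. A minimal pure-Python sketch of this arithmetic (an illustration, not the MindSpore implementation):\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def kl_normal_std(mu, sigma):\n",
    "    # KL( N(mu, sigma^2) || N(0, 1) ) for one latent dimension\n",
    "    return 0.5 * (mu ** 2 + sigma ** 2 - math.log(sigma ** 2) - 1.0)\n",
    "\n",
    "def neg_elbo(recon_loss, mus, sigmas):\n",
    "    # negative ELBO = reconstruction loss + total KL over latent dimensions\n",
    "    return recon_loss + sum(kl_normal_std(m, s) for m, s in zip(mus, sigmas))\n",
    "\n",
    "# the KL term vanishes exactly when the posterior equals the prior\n",
    "assert kl_normal_std(0.0, 1.0) == 0.0\n",
    "```"
   ]
  },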
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, with the trained VAE network, we can use `vae.generate_sample` to generate new samples. It takes the number of samples to generate and the shape of the generated samples, which must match the shape of the samples in the original dataset. We can also use `vae.reconstruct_sample` to reconstruct samples from the original dataset and test the reconstruction ability of the VAE network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The shape of the generated sample is  (64, 1, 32, 32)\n"
     ]
    }
   ],
   "source": [
    "generated_sample = vae.generate_sample(64, IMAGE_SHAPE)\n",
    "for sample in ds_train.create_dict_iterator():\n",
    "    sample_x = Tensor(sample['image'], dtype=mstype.float32)\n",
    "    reconstructed_sample = vae.reconstruct_sample(sample_x)\n",
    "print('The shape of the generated sample is ', generated_sample.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ConditionalVAE training process is similar to that of the VAE. Note, however, that when a trained ConditionalVAE network generates new samples or reconstructs samples, the label information must be provided; for example, the new samples generated below are 64 images of the digits 0 to 7."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The shape of the generated sample is  (64, 1, 32, 32)\n"
     ]
    }
   ],
   "source": [
    "sample_label = Tensor([i for i in range(0, 8)] * 8, dtype=mstype.int32)\n",
    "generated_sample = cvae.generate_sample(sample_label, 64, IMAGE_SHAPE)\n",
    "for sample in ds_train.create_dict_iterator():\n",
    "    sample_x = Tensor(sample['image'], dtype=mstype.float32)\n",
    "    sample_y = Tensor(sample['label'], dtype=mstype.int32)\n",
    "    reconstructed_sample = cvae.reconstruct_sample(sample_x, sample_y)\n",
    "print('The shape of the generated sample is ', generated_sample.shape)"
   ]
  },
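  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The label tensor above cycles through the digits 0 to 7 eight times, and the `Encoder` one-hot encodes each label before concatenating it with the flattened image. A pure-Python sketch of this conditioning step (the helper names are illustrative, not MindSpore APIs):\n",
    "\n",
    "```python\n",
    "labels = [i for i in range(0, 8)] * 8\n",
    "assert len(labels) == 64 and labels[:8] == list(range(8))\n",
    "\n",
    "def one_hot(label, num_classes=10):\n",
    "    vec = [0.0] * num_classes\n",
    "    vec[label] = 1.0\n",
    "    return vec\n",
    "\n",
    "def condition(flat_image, label):\n",
    "    # concatenate the flattened image with the one-hot label, as in Encoder.construct\n",
    "    return flat_image + one_hot(label)\n",
    "\n",
    "conditioned = condition([0.0] * 1024, labels[0])\n",
    "assert len(conditioned) == 1024 + 10\n",
    "```"
   ]
  },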
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To generate better and clearer samples, users can define more complex encoders and decoders; the example here uses only two fully connected layers and is intended only as a demonstration.\n",
    "\n",
    "## Bayesian Layers\n",
    "\n",
    "The following example uses the APIs in MindSpore's `nn.probability.bnn_layers` to implement a BNN image classification model. The APIs in `nn.probability.bnn_layers` include `NormalPrior`, `NormalPosterior`, `ConvReparam`, `DenseReparam`, and `WithBNNLossCell`. The biggest difference between a BNN and a DNN is that the weights and biases of a BNN layer are no longer fixed values but follow a distribution. `NormalPrior` and `NormalPosterior` generate normally distributed prior and posterior distributions, respectively; `ConvReparam` and `DenseReparam` are the Bayesian convolutional and fully connected layers implemented with the reparameterization trick; `WithBNNLossCell` wraps the BNN together with the loss function.\n",
    "\n",
    "For how to build a Bayesian neural network with the APIs in `nn.probability.bnn_layers` and use it for image classification, see the tutorial [Using Bayesian Networks](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/apply_deep_probability_programming.html#id3).\n",
    "\n",
    "## Bayesian Conversion\n",
    "\n",
    "For researchers unfamiliar with Bayesian models, MDP provides the Bayesian conversion interface (`mindspore.nn.probability.transform`), which converts a DNN (Deep Neural Network) model into a BNN (Bayesian Neural Network) model with one click.\n",
    "\n",
    "The `__init__` function of the model conversion API `TransformToBNN` is defined as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformToBNN:\n",
    "    def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):\n",
    "        net_with_loss = trainable_dnn.network\n",
    "        self.optimizer = trainable_dnn.optimizer\n",
    "        self.backbone = net_with_loss.backbone_network\n",
    "        self.loss_fn = getattr(net_with_loss, \"_loss_fn\")\n",
    "        self.dnn_factor = dnn_factor\n",
    "        self.bnn_factor = bnn_factor\n",
    "        self.bnn_loss_file = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter `trainable_dnn` is a trainable DNN model wrapped by `TrainOneStepCell`; `dnn_factor` is the coefficient of the overall network loss computed by the loss function, and `bnn_factor` is the coefficient of the KL divergence of each Bayesian layer.\n",
    "The API `TransformToBNN` provides two main functions:\n",
    "\n",
    "- Function 1: convert the whole model\n",
    "\n",
    "  The `transform_to_bnn_model` method converts the whole DNN model into a BNN model. It is defined as follows:\n",
    "\n",
    "  ```python\n",
    "    def transform_to_bnn_model(self,\n",
    "                               get_dense_args=lambda dp: {\"in_channels\": dp.in_channels, \"has_bias\": dp.has_bias,\n",
    "                                                          \"out_channels\": dp.out_channels, \"activation\": dp.activation},\n",
    "                               get_conv_args=lambda dp: {\"in_channels\": dp.in_channels, \"out_channels\": dp.out_channels,\n",
    "                                                         \"pad_mode\": dp.pad_mode, \"kernel_size\": dp.kernel_size,\n",
    "                                                         \"stride\": dp.stride, \"has_bias\": dp.has_bias,\n",
    "                                                         \"padding\": dp.padding, \"dilation\": dp.dilation,\n",
    "                                                         \"group\": dp.group},\n",
    "                               add_dense_args=None,\n",
    "                               add_conv_args=None):\n",
    "        r\"\"\"\n",
    "        Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell.\n",
    "\n",
    "        Args:\n",
    "            get_dense_args (function): The arguments gotten from the DNN full connection layer. Default: lambda dp:\n",
    "                {\"in_channels\": dp.in_channels, \"out_channels\": dp.out_channels, \"has_bias\": dp.has_bias}.\n",
    "            get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:\n",
    "                {\"in_channels\": dp.in_channels, \"out_channels\": dp.out_channels, \"pad_mode\": dp.pad_mode,\n",
    "                \"kernel_size\": dp.kernel_size, \"stride\": dp.stride, \"has_bias\": dp.has_bias}.\n",
    "            add_dense_args (dict): The new arguments added to BNN full connection layer. Default: {}.\n",
    "            add_conv_args (dict): The new arguments added to BNN convolutional layer. Default: {}.\n",
    "\n",
    "        Returns:\n",
    "            Cell, a trainable BNN model wrapped by TrainOneStepCell.\n",
    "        \"\"\"\n",
    "\n",
    "  ```\n",
    "\n",
    "  The parameter `get_dense_args` specifies which arguments to fetch from the fully connected layers of the DNN model; by default these are the arguments shared by the fully connected layers of the DNN and of the BNN (see the [API documentation](https://www.mindspore.cn/doc/api_python/zh-CN/master/mindspore/nn/mindspore.nn.Dense.html) for their meanings). `get_conv_args` specifies which arguments to fetch from the convolutional layers of the DNN model; by default these are the arguments shared by the convolutional layers of the DNN and of the BNN (see the [API documentation](https://www.mindspore.cn/doc/api_python/zh-CN/master/mindspore/nn/mindspore.nn.Conv2d.html)). The parameters `add_dense_args` and `add_conv_args` specify new argument values to set for the BNN layers. Note that the arguments in `add_dense_args` must not duplicate those in `get_dense_args`, and likewise for `add_conv_args` and `get_conv_args`.\n",
    "\n",
    "- Function 2: convert layers of a specified type\n",
    "\n",
    "  The `transform_to_bnn_layer` method converts layers of a specified type (`nn.Dense` or `nn.Conv2d`) in the DNN model into the corresponding Bayesian layers. It is defined as follows:\n",
    "\n",
    "  ```python\n",
    "   def transform_to_bnn_layer(self, dnn_layer, bnn_layer, get_args=None, add_args=None):\n",
    "        r\"\"\"\n",
    "        Transform a specific type of layers in DNN model to corresponding BNN layer.\n",
    "\n",
    "        Args:\n",
    "            dnn_layer (Cell): The type of DNN layer to be transformed to a BNN layer. The optional values are\n",
    "                nn.Dense and nn.Conv2d.\n",
    "            bnn_layer (Cell): The type of BNN layer to be transformed to. The optional values are\n",
    "                DenseReparam and ConvReparam.\n",
    "            get_args (dict): The arguments gotten from the DNN layer. Default: None.\n",
    "            add_args (dict): The new arguments added to BNN layer. Default: None.\n",
    "\n",
    "        Returns:\n",
    "            Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layers are transformed to the corresponding Bayesian layers.\n",
    "        \"\"\"\n",
    "  ```\n",
    "\n",
    "  The parameter `dnn_layer` specifies which type of DNN layer to convert into a BNN layer, and `bnn_layer` specifies which type of BNN layer it will be converted into; `get_args` and `add_args` specify which arguments to fetch from the DNN layer and which arguments of the BNN layer to reassign, respectively.\n",
    "\n",
    "For how to use the API `TransformToBNN` in MindSpore, see the tutorial [One-click Conversion from DNN to BNN](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/apply_deep_probability_programming.html#dnnbnn)."
   ]
  },
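  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `DenseReparam` and `ConvReparam` layers mentioned above sample their weights with the reparameterization trick, so that sampling stays differentiable with respect to the distribution parameters. A minimal pure-Python sketch of the idea (the softplus parameterization of sigma is an illustrative assumption, not necessarily the MindSpore internals):\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "def sample_weight(mu, rho, eps=None):\n",
    "    # sigma = log(1 + exp(rho)) keeps the standard deviation positive\n",
    "    sigma = math.log1p(math.exp(rho))\n",
    "    if eps is None:\n",
    "        eps = random.gauss(0.0, 1.0)  # eps ~ N(0, 1), independent of mu and rho\n",
    "    # w = mu + sigma * eps is a sample from N(mu, sigma^2),\n",
    "    # but gradients can still flow back to mu and rho\n",
    "    return mu + sigma * eps\n",
    "\n",
    "# with eps = 0 the sampled weight equals its mean\n",
    "assert sample_weight(0.3, 0.0, eps=0.0) == 0.3\n",
    "```"
   ]
  },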
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bayesian Toolbox\n",
    "\n",
    "One advantage of Bayesian neural networks is that they can capture uncertainty. MDP provides an uncertainty estimation toolbox at the upper layer (`mindspore.nn.probability.toolbox`), with which users can conveniently compute uncertainty. Uncertainty measures how unsure a deep learning model is about its predictions. Currently, most deep learning algorithms can only give predictions with high confidence, but cannot judge how certain those predictions are. There are two main types of uncertainty: aleatoric uncertainty and epistemic uncertainty.\n",
    "\n",
    "- Aleatoric uncertainty: describes the inherent noise in the data, i.e., unavoidable error; it cannot be reduced by collecting more data.\n",
    "- Epistemic uncertainty: the model's own estimate of the input data may be inaccurate due to insufficient training, insufficient training data, and so on; it can be alleviated by, for example, adding more training data.\n",
    "\n",
    "The interface of the uncertainty evaluation toolbox is as follows:\n",
    "\n",
    "- `model`: the trained model whose uncertainty is to be evaluated.\n",
    "- `train_dataset`: the dataset used for training, of iterator type.\n",
    "- `task_type`: the type of the model, a string; either \"regression\" or \"classification\".\n",
    "- `num_classes`: for a classification model, the number of class labels must be specified.\n",
    "- `epochs`: the number of epochs used to train the uncertainty model.\n",
    "- `epi_uncer_model_path`: the path used to store or load the model that computes epistemic uncertainty.\n",
    "- `ale_uncer_model_path`: the path used to store or load the model that computes aleatoric uncertainty.\n",
    "- `save_model`: a boolean indicating whether to store the model.\n",
    "\n",
    "The model needs to be trained before use. Taking LeNet5 as an example, the usage is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 1875, loss is 0.07030643\n",
      "epoch: 1 step: 1875, loss is 0.10861007\n",
      "The shape of epistemic uncertainty is  (32, 10)\n",
      "The shape of aleatoric uncertainty is  (32,)\n"
     ]
    }
   ],
   "source": [
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor\n",
    "from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation\n",
    "from mindspore import load_checkpoint, load_param_into_net\n",
    "\n",
    "\n",
    "class LeNet5(nn.Cell):\n",
    "    \"\"\"Lenet network structure.\"\"\"\n",
    "    # define the operator required\n",
    "    def __init__(self, num_class=10, num_channel=1):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')\n",
    "        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))\n",
    "        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))\n",
    "        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    # use the preceding operators to construct networks\n",
    "    def construct(self, x):\n",
    "        x = self.max_pool2d(self.relu(self.conv1(x)))\n",
    "        x = self.max_pool2d(self.relu(self.conv2(x)))\n",
    "        x = self.flatten(x)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # get trained model\n",
    "    network = LeNet5()\n",
    "    param_dict = load_checkpoint('checkpoint_lenet.ckpt')\n",
    "    load_param_into_net(network, param_dict)\n",
    "    # get train and eval dataset\n",
    "    ds_train = create_dataset('MNIST/train')\n",
    "    ds_eval = create_dataset('MNIST/test')\n",
    "    evaluation = UncertaintyEvaluation(model=network,\n",
    "                                       train_dataset=ds_train,\n",
    "                                       task_type='classification',\n",
    "                                       num_classes=10,\n",
    "                                       epochs=1,\n",
    "                                       epi_uncer_model_path=None,\n",
    "                                       ale_uncer_model_path=None,\n",
    "                                       save_model=False)\n",
    "    for eval_data in ds_eval.create_dict_iterator():\n",
    "        eval_data = Tensor(eval_data['image'], mstype.float32)\n",
    "        epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)\n",
    "        aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)\n",
    "    print('The shape of epistemic uncertainty is ', epistemic_uncertainty.shape)\n",
    "    print('The shape of aleatoric uncertainty is ', aleatoric_uncertainty.shape)"
   ]
  },
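  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output shapes above follow from what each method measures: `eval_epistemic_uncertainty` returns one value per class per sample, while `eval_aleatoric_uncertainty` returns one value per sample. A conceptual pure-Python sketch of how a per-class epistemic estimate can arise from the spread across several stochastic forward passes (the numbers are made up for illustration; this is not the toolbox's internal algorithm):\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "\n",
    "# class-probability vectors from three hypothetical stochastic forward passes\n",
    "stochastic_passes = [\n",
    "    [0.7, 0.2, 0.1],\n",
    "    [0.6, 0.3, 0.1],\n",
    "    [0.8, 0.1, 0.1],\n",
    "]\n",
    "\n",
    "def epistemic_per_class(passes):\n",
    "    # variance of each class probability across the passes\n",
    "    return [statistics.pvariance(col) for col in zip(*passes)]\n",
    "\n",
    "per_class = epistemic_per_class(stochastic_passes)\n",
    "assert len(per_class) == 3   # one value per class label, as in the (32, 10) shape\n",
    "assert per_class[2] == 0.0   # a class the passes always agree on has zero variance\n",
    "```"
   ]
  },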
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`eval_epistemic_uncertainty` computes epistemic uncertainty, also called model uncertainty; each predicted label of each sample gets an uncertainty value. `eval_aleatoric_uncertainty` computes aleatoric uncertainty, also called data uncertainty; each sample gets a single uncertainty value.\n",
    "\n",
    "Uncertainty values are greater than or equal to 0; the larger the value, the higher the uncertainty."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
